Skip to content

Replace jnp.interp with interpax.interp1d in thermo.py#24

Merged
cgiovanetti merged 1 commit intomainfrom
siddharth/interpax-thermo
Jan 28, 2026
Merged

Replace jnp.interp with interpax.interp1d in thermo.py#24
cgiovanetti merged 1 commit intomainfrom
siddharth/interpax-thermo

Conversation

@smsharma
Copy link
Collaborator

@smsharma smsharma commented Jan 24, 2026

Summary

Follow-up to #16 as suggested by @cgiovanetti. This PR replaces jnp.interp with interpax.interp1d in thermo.py for consistency with abundances.py and better performance.

Changes:

  • Replace 6 jnp.interp calls with interpax.interp1d
  • Flip QED correction tables at load time (instead of at runtime) since interpax requires monotonically increasing x coordinates
  • Use constant extrapolation values (0.0, 1.0) for collision factor tables (matching original behavior)

Functions updated:

  • rho_EM_std: QED corrections
  • p_EM_std: QED correction
  • rho_plus_p_EM_std: QED correction
  • G_nue_with_me: Collision factor interpolation
  • G_numt_with_me: Collision factor interpolation

Caveat

For the collision factor interpolations (G_nue_with_me, G_numt_with_me), the original code used left=f_tab[0,1], right=f_tab[-1,1] to extrapolate with boundary values. Since interpax.interp1d requires the extrap parameter to be concrete values (not traced JAX arrays), I hardcoded extrap=(0.0, 1.0) based on the current data files:

f_nue_ann  - first: 0.0, last: 1.0
f_nue_scat - first: 0.0, last: 1.0
f_numu_ann - first: 0.0, last: 1.0
f_numu_scat - first: 0.0, last: 1.0

If these data files are ever updated with different boundary values, the hardcoded extrapolation values would need to be updated accordingly.

Test plan

  • pytest pytest/test_abundances.py passes
  • Reference BBN abundances match expected values

🤖 Generated with Claude Code

Follow-up to PR #16 which replaced jnp.interp with interpax in abundances.py.
This change applies the same improvement to thermo.py for consistency and
better performance.

Changes:
- Import interpax module
- Flip QED correction tables at load time (instead of at each call) for
  monotonically increasing x coordinates required by interpax
- Replace 6 jnp.interp calls with interpax.interp1d:
  - rho_EM_std: 2 calls for QED corrections
  - p_EM_std: 1 call for QED correction
  - rho_plus_p_EM_std: 1 call for QED correction
  - G_nue_with_me: 1 call for collision factor interpolation
  - G_numt_with_me: 1 call for collision factor interpolation

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@smsharma smsharma requested a review from cgiovanetti January 24, 2026 04:49
Copy link
Owner

@cgiovanetti cgiovanetti left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Checked timing is also improved with this change

@cgiovanetti cgiovanetti merged commit da88e4d into main Jan 28, 2026
1 check passed
@cgiovanetti cgiovanetti deleted the siddharth/interpax-thermo branch January 28, 2026 19:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants